Case Study - Parameter Tuning using Spam Data
1 Code
Use the button to the right to expand the code for reading in the emails and creating the dataset that will be used to model.
#setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
# Root directory of the data, relative to the current working directory.
spamPath = "../Data"
# Names of the sub-directories found under <spamPath>/messages.
dirNames = list.files(path = paste(spamPath, "messages", sep = .Platform$file.sep))
# Full path of each message directory (one element per entry in dirNames).
fullDirNames = paste(spamPath, "messages", dirNames, sep = .Platform$file.sep)
# Full paths of the files in the first message directory only.
fileNames = list.files(fullDirNames[1], full.names = TRUE)
# Split a message into header and body at the first blank line.
#
# msg: character vector holding one line of the email per element.
# Returns list(header = <lines before the first blank line>,
#              body   = <lines after the first blank line>).
# A message without any blank line is returned as all header with an empty
# body; a message whose first line is blank gets an empty header.
splitMessage = function(msg) {
  # look for the first blank line
  splitPoint = match("", msg)
  # no blank line: place the split just past the end so everything is header
  if (is.na(splitPoint)) splitPoint = length(msg) + 1
  # seq_len() (unlike 1:(splitPoint - 1)) yields an empty index when the
  # blank line is the very first line
  header = msg[seq_len(splitPoint - 1)]
  # body are rows after the first blank line
  body = msg[-seq_len(splitPoint)]
  return(list(header = header, body = body))
}
# Extract the MIME boundary string(s) declared in header lines.
#
# header: character vector; only elements containing "boundary=" are used.
# Returns one boundary string per matching element, with double quotes removed.
getBoundary = function(header) {
  # keep only the elements that declare a boundary
  declared = header[grep("boundary=", header)]
  # drop double quotes, then keep what follows "boundary=" up to an optional ";"
  unquoted = gsub('"', "", declared)
  gsub(".*boundary= *([^;]*);?.*", "\\1", unquoted)
}
# Convert the raw header lines of one email into a named character vector.
#
# header: character vector of header lines (as returned by splitMessage).
# Returns a character vector whose names are the header keys; a key that
# occurs several times appears once per value.
processHeader = function(header) {
  # modify the first line to create a key:value pair
  header[1] = sub("^From", "Top-From:", header[1])
  # read.dcf() does not close a connection that was handed to it already
  # open, so close it here to avoid leaking one connection per email
  con = textConnection(header)
  on.exit(close(con), add = TRUE)
  # reads data in the format key: value and accounts for continuation on subsequent lines
  headerMat = read.dcf(con, all = TRUE)
  # convert to character vector with key as the name for each value
  headerVec = unlist(headerMat)
  # number of values stored under each key, used to repeat the key names
  dupKeys = sapply(headerMat, function(x) length(unlist(x)))
  names(headerVec) = rep(colnames(headerMat), dupKeys)
  return(headerVec)
}
# Separate the attachments from the body of a multipart message.
#
# body:        character vector with the lines of the message body.
# contentType: the Content-Type header value; its boundary string is
#              extracted with getBoundary().
# Returns list(body = <message lines without the attachments>,
#              attachDF = NULL, or a data frame with one row per attachment
#                         holding aLen (number of lines between consecutive
#                         boundary markers) and aType (MIME type, or NA)).
processAttach = function(body, contentType){
n = length(body)
boundary = getBoundary(contentType)
# lines equal to "--<boundary>" start a part
bString = paste("--", boundary, sep = "")
bStringLocs = which(bString == body)
# the line equal to "--<boundary>--" closes the last part
eString = paste("--", boundary, "--", sep = "")
eStringLoc = which(eString == body)
# no closing marker: treat the last line of the body as the end
if (length(eStringLoc) == 0) eStringLoc = n
if (length(bStringLocs) <= 1) {
# at most one part: there are no attachments, the message runs to the end
attachLocs = NULL
msgLastLine = n
# no boundary line at all: start the message at line 1 (0 + 1 below)
if (length(bStringLocs) == 0) bStringLocs = 0
} else {
# every boundary after the first starts an attachment; the closing marker
# ends the last one
attachLocs = c(bStringLocs[ -1 ], eStringLoc)
msgLastLine = bStringLocs[2] - 1
}
# the message is the first part, between the first and second boundary
msg = body[ (bStringLocs[1] + 1) : msgLastLine]
# keep any text that follows the closing marker
# NOTE(review): eStringLoc is used as a scalar here; a body containing the
# closing marker more than once would make this condition longer than 1 - confirm
if ( eStringLoc < n )
msg = c(msg, body[ (eStringLoc + 1) : n ])
if ( !is.null(attachLocs) ) {
# distance in lines between consecutive attachment markers
attachLens = diff(attachLocs, lag = 1)
attachTypes = mapply(function(begL, endL) {
# look for a Content-Type line between the two markers
CTloc = grep("^[Cc]ontent-[Tt]ype", body[ (begL + 1) : (endL - 1)])
if ( length(CTloc) == 0 ) {
MIMEType = NA
} else {
# use the first Content-Type line; strip quotes and keep the text before ";"
CTval = body[ begL + CTloc[1] ]
CTval = gsub('"', "", CTval )
MIMEType = sub(" *[Cc]ontent-[Tt]ype: *([^;]*);?.*", "\\1", CTval)
}
return(MIMEType)
}, attachLocs[-length(attachLocs)], attachLocs[-1])
}
if (is.null(attachLocs)) return(list(body = msg, attachDF = NULL) )
return(list(body = msg,
attachDF = data.frame(aLen = attachLens,
aType = unlist(attachTypes),
stringsAsFactors = FALSE)))
}
# Read every email file in a directory.
#
# dirName: path of the directory to read.
# Returns a list with one character vector of lines per file, named by the
# full path of the file it was read from. Files whose name ends in "cmds"
# are skipped.
readEmail = function(dirName) {
  # retrieve the names of files in directory
  fileNames = list.files(dirName, full.names = TRUE)
  # drop files that are not email
  notEmail = grep("cmds$", fileNames)
  if ( length(notEmail) > 0) fileNames = fileNames[ - notEmail ]
  # read all files in the directory
  messages = lapply(fileNames, readLines, encoding = "latin1")
  # lapply() over an unnamed vector returns an unnamed list; attach the file
  # names so that callers using names() on the result can identify each email
  names(messages) = fileNames
  messages
}
# Read and parse every email in one directory.
#
# dirName: directory holding the email files.
# isSpam:  label stored with every email read from this directory.
# Returns (invisibly) a list with one element per email, each a list with
# elements isSpam, header, body and attach.
processAllEmail = function(dirName, isSpam = FALSE)
{
  # load the raw lines of every message
  rawMsgs = readEmail(dirName)
  msgNames = names(rawMsgs)
  nMsgs = length(rawMsgs)

  # header/body split; the raw lines are not needed afterwards
  parts = lapply(rawMsgs, splitMessage)
  rm(rawMsgs)

  # headers as named character vectors, plus each Content-Type value
  headers = lapply(parts, function(part)
    processHeader(part$header))
  ctypes = sapply(headers, function(h)
    h["Content-Type"])

  # bodies; the split parts are not needed afterwards
  bodies = lapply(parts, function(part) part$body)
  rm(parts)

  # messages whose content type starts with "multi" carry attachments
  multiIdx = grep("^ *multi", tolower(ctypes))

  # strip the attachments from those bodies and collect their summaries
  attResults = mapply(processAttach, bodies[multiIdx],
                      ctypes[multiIdx], SIMPLIFY = FALSE)
  bodies[multiIdx] = lapply(attResults, function(res)
    res$body)
  attachments = vector("list", length = nMsgs )
  attachments[ multiIdx ] = lapply(attResults,
                                   function(res) res$attachDF)

  # assemble one record per email
  buildEmail = function(header, body, attach, isSpam) {
    list(isSpam = isSpam, header = header,
         body = body, attach = attach)
  }
  emails = mapply(buildEmail,
                  headers, bodies, attachments,
                  rep(isSpam, nMsgs), SIMPLIFY = FALSE )
  names(emails) = msgNames
  invisible(emails)
}
# Process every message directory with processAllEmail(). The isSpam labels
# are FALSE for the first three directories and TRUE for the last two.
# NOTE(review): this relies on fullDirNames listing the three not-spam
# directories before the two spam directories - confirm the ordering.
emailStruct = mapply(processAllEmail, fullDirNames, isSpam = rep( c(FALSE, TRUE), 3:2))
# flatten the per-directory lists into a single list of emails
emailStruct = unlist(emailStruct, recursive = FALSE)
# save the processed email file
save(emailStruct, file="emailXX.rda")
# Build the data frame of derived variables, one column per entry of
# `operations` and one row per email.
#
# email:      list of emails (defaults to emailStruct).
# operations: named list; each entry is either a function taking one email,
#             or an expression evaluated with the email bound to `msg`.
# verbose:    if TRUE, print the name of each variable as it is computed.
# Returns the data frame invisibly.
createDerivedDF = function(email = emailStruct, operations = funcList, verbose = FALSE) {
  opNames = names(operations)

  # compute the values of a single derived variable across all emails
  deriveColumn = function(id) {
    if (verbose) print(id)
    op = operations[[id]]
    if (is.function(op)) {
      return(sapply(email, op))
    }
    # expressions refer to the current email as `msg`
    sapply(email, function(msg) eval(op))
  }

  columns = lapply(opNames, deriveColumn)
  df = as.data.frame(columns)
  names(df) = opNames
  invisible(df)
}
# create list of functions that can be applied to emails
# Each entry derives one variable from a single email `msg`, a list with the
# elements isSpam, header (named character vector), body (character vector of
# lines) and attach (NULL or a data frame with columns aLen and aType).
# An entry is either a function(msg) or an expression that refers to `msg`.
funcList = list(
  # the spam label stored with the message
  isSpam =
    expression(msg$isSpam)
  ,
  # is the message a response from an earlier message
  isRe =
    function(msg) {
      # Can have a Fwd: Re: ... but we are not looking for this here.
      # We may want to look at In-Reply-To field.
      "Subject" %in% names(msg$header) &&
        length(grep("^[ \t]*Re:", msg$header[["Subject"]])) > 0
    }
  ,
  # number of lines in the body
  numLines =
    function(msg) length(msg$body)
  ,
  # number of characters in the body
  bodyCharCt =
    function(msg)
      sum(nchar(msg$body))
  ,
  # does the Reply-To header contain an underscore (and an alphanumeric)
  underscore =
    function(msg) {
      if(!"Reply-To" %in% names(msg$header))
        return(FALSE)
      txt <- msg$header[["Reply-To"]]
      length(grep("_", txt)) > 0 &&
        length(grep("[0-9A-Za-z]+", txt)) > 0
    }
  ,
  # number of exclamation marks in the subject (NA if no/empty subject)
  subExcCt =
    function(msg) {
      x = msg$header["Subject"]
      if(length(x) == 0 || sum(nchar(x)) == 0 || is.na(x))
        return(NA)
      sum(nchar(gsub("[^!]","", x)))
    }
  ,
  # number of question marks in the subject (NA if no/empty subject)
  subQuesCt =
    function(msg) {
      x = msg$header["Subject"]
      if(length(x) == 0 || sum(nchar(x)) == 0 || is.na(x))
        return(NA)
      sum(nchar(gsub("[^?]","", x)))
    }
  ,
  # number of attachments
  numAtt =
    function(msg) {
      if (is.null(msg$attach)) return(0)
      else nrow(msg$attach)
    }
  ,
  # is any priority header set to high
  priority =
    function(msg) {
      ans <- FALSE
      # Look for names X-Priority, Priority, X-Msmail-Priority
      # Look for high any where in the value
      ind = grep("priority", tolower(names(msg$header)))
      if (length(ind) > 0) {
        ans <- length(grep("high", tolower(msg$header[ind]))) >0
      }
      ans
    }
  ,
  # number of recipient addresses in To, Cc and Bcc (NA if none of them exist)
  numRec =
    function(msg) {
      # unique or not.
      els = getMessageRecipients(msg$header)
      if(length(els) == 0)
        return(NA)
      # Split each line by "," and in each of these elements, look for
      # the @ sign.
      tmp = sapply(strsplit(els, ","), function(x) grep("@", x))
      sum(sapply(tmp, length))
    }
  ,
  # percentage of alphabetic body characters that are capital letters
  perCaps =
    function(msg)
    {
      body = paste(msg$body, collapse = "")
      # Return NA if the body of the message is "empty"
      if(length(body) == 0 || nchar(body) == 0) return(NA)
      # Eliminate non-alpha characters and empty lines
      body = gsub("[^[:alpha:]]", "", body)
      els = unlist(strsplit(body, ""))
      ctCap = sum(els %in% LETTERS)
      100 * ctCap / length(els)
    }
  ,
  # is there an In-Reply-To header
  isInReplyTo =
    function(msg)
    {
      "In-Reply-To" %in% names(msg$header)
    }
  ,
  # are the recipient fields already in sorted order
  sortedRec =
    function(msg)
    {
      ids = getMessageRecipients(msg$header)
      all(sort(ids) == ids)
    }
  ,
  # does the subject have punctuation embedded between letters
  subPunc =
    function(msg)
    {
      if("Subject" %in% names(msg$header)) {
        el = gsub("['/.:@-]", "", msg$header["Subject"])
        length(grep("[A-Za-z][[:punct:]]+[A-Za-z]", el)) > 0
      }
      else
        FALSE
    },
  # hour of the day taken from the Date header, shifted by the numeric time
  # zone offset when one is present (NA if there is no usable Date)
  hour =
    function(msg)
    {
      # header["Date"] yields NA (never NULL) when the key is absent, so
      # test for the key explicitly instead of using is.null()
      if ( !("Date" %in% names(msg$header)) ) return(NA)
      date = msg$header[["Date"]]
      if ( is.na(date) ) return(NA)
      # Need to handle that there may be only one digit in the hour
      locate = regexpr("[0-2]?[0-9]:[0-5][0-9]:[0-5][0-9]", date)
      if (locate < 0)
        locate = regexpr("[0-2]?[0-9]:[0-5][0-9]", date)
      if (locate < 0) return(NA)
      hour = substring(date, locate, locate+1)
      hour = as.numeric(gsub(":", "", hour))
      locate = regexpr("PM", date)
      # 12 PM is already hour 12; only 1 PM - 11 PM need the shift
      if (locate > 0 && hour < 12) hour = hour + 12
      locate = regexpr("[+-][0-2][0-9]00", date)
      if (locate < 0) offset = 0
      else offset = as.numeric(substring(date, locate, locate + 2))
      (hour - offset) %% 24
    }
  ,
  # do more than half of the attachments have a type matching html/plain/text
  multipartText =
    function(msg)
    {
      if (is.null(msg$attach)) return(FALSE)
      numAtt = nrow(msg$attach)
      length(grep("(html|plain|text)", msg$attach$aType)) > (numAtt/2)
    }
  ,
  # is any attachment an image
  hasImages =
    function(msg)
    {
      if (is.null(msg$attach)) return(FALSE)
      length(grep("^ *image", tolower(msg$attach$aType))) > 0
    }
  ,
  # does any attachment type mention pgp
  isPGPsigned =
    function(msg)
    {
      if (is.null(msg$attach)) return(FALSE)
      length(grep("pgp", tolower(msg$attach$aType))) > 0
    },
  # percentage of non-whitespace body characters that sit inside <...> tags,
  # for messages whose Content-Type mentions html (0 otherwise)
  perHTML =
    function(msg)
    {
      if(! ("Content-Type" %in% names(msg$header))) return(0)
      el = tolower(msg$header["Content-Type"])
      if (length(grep("html", el)) == 0) return(0)
      els = gsub("[[:space:]]", "", msg$body)
      totchar = sum(nchar(els))
      # an empty body has no markup; avoid returning 0/0 = NaN
      if (totchar == 0) return(0)
      totplain = sum(nchar(gsub("<[^<]+>", "", els )))
      100 * (totchar - totplain)/totchar
    },
  # does the subject contain any of the SpamCheckWords (NA if no subject)
  subSpamWords =
    function(msg)
    {
      if("Subject" %in% names(msg$header))
        length(grep(paste(SpamCheckWords, collapse = "|"),
                    tolower(msg$header["Subject"]))) > 0
      else
        NA
    }
  ,
  # percentage of blank characters in the subject (NA if no subject)
  subBlanks =
    function(msg)
    {
      if("Subject" %in% names(msg$header)) {
        x = msg$header["Subject"]
        # should we count blank subject line as 0 or 1 or NA?
        # an empty subject is treated like a one-character subject so that
        # the ratio below never becomes 0/0 = NaN
        if (nchar(x) <= 1) return(0)
        else 100 *(1 - (nchar(gsub("[[:blank:]]", "", x))/nchar(x)))
      } else NA
    }
  ,
  # is there no "@host" part in the Message-* header (NA if no such header)
  noHost =
    function(msg)
    {
      # Or use partial matching.
      idx = pmatch("Message-", names(msg$header))
      if(is.na(idx)) return(NA)
      tmp = msg$header[idx]
      return(length(grep(".*@[^[:space:]]+", tmp)) == 0)
    }
  ,
  # does the sender's login (the text before "@") end in a digit
  numEnd =
    function(msg)
    {
      # If we just do a grep("[0-9]@", )
      # we get matches on messages that have a From something like
      # " \"marty66@aol.com\" <synjan@ecis.com>"
      # and the marty66 is the "user's name" not the login
      # So we can be more precise if we want.
      x = names(msg$header)
      # use From when present, fall back to X-From, otherwise give up
      if ( "From" %in% x ) sender = msg$header[["From"]]
      else if ( "X-From" %in% x ) sender = msg$header[["X-From"]]
      else return(NA)
      login = gsub("^.*<", "", sender)
      login = strsplit(login, "@")[[1]][1]
      length(grep("[0-9]+$", login)) > 0
    },
  # identifies if all alpha characters in the subject line are upper case
  isYelling =
    function(msg)
    {
      if ( "Subject" %in% names(msg$header) ) {
        el = gsub("[^[:alpha:]]", "", msg$header["Subject"])
        if (nchar(el) > 0) nchar(gsub("[A-Z]", "", el)) < 1
        else FALSE
      }
      else
        NA
    },
  # percentage of body lines that start with ">" (NA for an empty body)
  forwards =
    function(msg)
    {
      x = msg$body
      if(length(x) == 0 || sum(nchar(x)) == 0)
        return(NA)
      ans = length(grep("^[[:space:]]*>", x))
      100 * ans / length(x)
    },
  # does a body line consist of an "original message" marker
  isOrigMsg =
    function(msg)
    {
      x = msg$body
      if(length(x) == 0) return(NA)
      length(grep("^[^[:alpha:]]*original[^[:alpha:]]+message[^[:alpha:]]*$",
                  tolower(x) ) ) > 0
    },
  # does a body line start with "dear sir" or "dear madam"
  isDear =
    function(msg)
    {
      x = msg$body
      if(length(x) == 0) return(NA)
      length(grep("^[[:blank:]]*dear +(sir|madam)\\>",
                  tolower(x))) > 0
    },
  # does the body contain "wrote:" (or schrieb:, ecrit:, escribe:)
  isWrote =
    function(msg)
    {
      x = msg$body
      if(length(x) == 0) return(NA)
      length(grep("(wrote|schrieb|ecrit|escribe):", tolower(x) )) > 0
    },
  # average length of the alphabetic words in the body (0 for an empty body)
  avgWordLen =
    function(msg)
    {
      txt = paste(msg$body, collapse = " ")
      if(length(txt) == 0 || sum(nchar(txt)) == 0) return(0)
      txt = gsub("[^[:alpha:]]", " ", txt)
      words = unlist(strsplit(txt, "[[:blank:]]+"))
      wordLens = nchar(words)
      mean(wordLens[ wordLens > 0 ])
    }
  ,
  # number of dollar signs in the body (NA for an empty body)
  numDlr =
    function(msg)
    {
      x = paste(msg$body, collapse = "")
      if(length(x) == 0 || sum(nchar(x)) == 0)
        return(NA)
      nchar(gsub("[^$]","", x))
    }
)
# Words and phrases searched for (case-insensitively, via the subSpamWords
# variable) in the subject line.
SpamCheckWords = c(
  "viagra", "pounds", "free", "weight", "guarantee", "million",
  "dollars", "credit", "risk", "prescription", "generic", "drug",
  "financial", "save", "dollar", "erotic", "million", "barrister",
  "beneficiary", "easy", "money back", "money", "credit card"
)
# Collect the recipient fields of a header.
#
# header: named character vector of header values.
# Returns the values of the To, Cc and Bcc fields, in that order, skipping
# fields that are absent; character(0) when none is present.
getMessageRecipients = function(header) {
  recipients = character(0)
  for (field in c("To", "Cc", "Bcc")) {
    if (field %in% names(header)) {
      recipients = c(recipients, header[[field]])
    }
  }
  recipients
}
# add in additional derived variables
emailDF = createDerivedDF(emailStruct)
save(emailDF, file="emailDFknit.rda")
Use the buttons to the right to expand the code for recursive partitioning using caret with method=rpart from R.
load(file = "emailDFknit.rda")
### recursive partitioning ###
library(rpart)
library(caret)
# Our derived variables include TRUE/FALSE (logical) columns. The models need
# numbers, so convert those columns to 1/0. (Missing values are handled
# separately, after this function has been applied.)
#
# data: a data frame.
# Returns a data frame with the logical columns converted to numeric and
# placed first, followed by the remaining columns unchanged. A data frame
# without logical columns is returned as is.
setupRnum = function(data) {
  isLogical = vapply(data, is.logical, logical(1))
  # nothing to convert; negative indexing with an empty index would
  # otherwise drop every column
  if (!any(isLogical)) return(data)
  # single-bracket indexing keeps a data frame even when only one column is
  # selected, so column names are never lost
  converted = data[isLogical]
  converted[] = lapply(converted, as.numeric)
  cbind(converted, data[!isLogical])
}
emailDFnum = setupRnum(emailDF)
emailDFnum[is.na(emailDFnum)] <- 0
# save(emailDFnum,file='emailDFnum.rda')
# Because our authors prefer Type I/II errors, but the cool kids know that
# precision/recall/F1 is where its at, while the default of caret is accuracy and
# kappa. To get us all on the same page, I create a function that returns the
# metrics we want. However, rather than re-invent the wheel, I just install a
# package. I am not sure if it had Type I/II errors so those I made my self.
library(MLmetrics)
# Summary function for caret (passed as trainControl(summaryFunction = f1)).
#
# data:  data frame with the columns obs (observed class) and pred
#        (predicted class).
# lev:   the class levels; the first level is used as the positive class for
#        F1, precision and recall.
# model: unused.
# Returns a named numeric vector: prec, F1, rec, Type_I_err, Type_II_err.
#
# NOTE(review): with isSpam coded as 0/1 the first level is presumably "0"
# (not spam), so prec/rec/F1 would describe the not-spam class - confirm this
# is the intended positive class.
# NOTE(review): Type_I_err counts pred == 0 & obs == 1 and Type_II_err counts
# pred == 1 & obs == 0, each as a share of all predictions - confirm these
# labels match the definitions used in the report.
f1 <- function(data, lev = NULL, model = NULL) {
  f1_val <- F1_Score(y_pred = data$pred, y_true = data$obs, positive = lev[1])
  p <- Precision(y_pred = data$pred, y_true = data$obs, positive = lev[1])
  r <- Recall(y_pred = data$pred, y_true = data$obs, positive = lev[1])
  fp <- sum(data$pred == 0 & data$obs == 1)/length(data$pred)
  fn <- sum(data$pred == 1 & data$obs == 0)/length(data$pred)
  c(prec = p, F1 = f1_val, rec = r, Type_I_err = fp, Type_II_err = fn)
}
library(naivebayes)
library(e1071)
library(rpart)
### rpart with defaults ###
train_control <- trainControl(method = "cv", number = 5, savePredictions = "final",
summaryFunction = f1)
set.seed(1234)
model_rpart_defaults <- caret::train(as.factor(isSpam) ~ ., data = emailDFnum, trControl = train_control,
method = "rpart", na.action = na.omit)
# create the dataset used to loop through different values of minsplit and
# maxdepth
control_grid <- expand.grid(minsplit = seq(5, 25, 1), maxdepth = seq(15, 30, 1))
# create the grid used inside caret::train to grid search on cp
cart_grid <- expand.grid(cp = seq(from = 0, to = 0.1, by = 0.001))
# train function for 5-fold cross validation
train_control <- trainControl(method = "cv", number = 5, savePredictions = "final",
summaryFunction = f1)# loop to tune cp, minsplit and maxdepth
# For every (minsplit, maxdepth) pair, let caret tune cp by cross-validation
# and keep the full table of results. The per-pair tables are collected in a
# preallocated list and bound once at the end instead of growing a data frame
# inside the loop.
result_list <- vector("list", nrow(control_grid))
for (i in seq_len(nrow(control_grid))) {
  # same seed for every fit so that each one uses the same folds
  set.seed(1234)
  model_rpart <- caret::train(as.factor(isSpam) ~ ., data = emailDFnum, trControl = train_control,
    method = "rpart", control = rpart.control(minsplit = control_grid$minsplit[i],
      maxdepth = control_grid$maxdepth[i]), tuneGrid = cart_grid, na.action = na.omit)
  # one row per cp value, tagged with the control parameters used for this fit
  results = as.data.frame(model_rpart$results)
  results$minsplit = control_grid$minsplit[i]
  results$maxdepth = control_grid$maxdepth[i]
  result_list[[i]] <- results
  # print(paste0(round(i/nrow(control_grid),4)*100,'% complete',sep=''))
}
final_result <- do.call(rbind, result_list)
load(file = "emailDFnum.rda")
load(file = "final_result_knit.rda")
# Row of the tuning results with the highest precision. which.max() returns a
# single row even when several parameter combinations tie, so the values
# passed to rpart.control() below are always scalars.
results_tuned = final_result[which.max(final_result$prec), ]
# rerun the model with the winning parameters to enable tree plotting
set.seed(1234)
model_rpart_tuned <- caret::train(as.factor(isSpam) ~ ., data = emailDFnum, trControl = train_control,
  method = "rpart", control = rpart.control(minsplit = results_tuned$minsplit,
    maxdepth = results_tuned$maxdepth), tuneGrid = cart_grid, na.action = na.omit)
2 Introduction
Many machine learning packages come with convenient features, including useful default parameters for many algorithms. While the default parameters are often good, they can be tailored to the specific problem being analyzed. In this case study, we will optimize the rpart decision tree algorithm to classify Spam messages from the SpamAssassin corpus. We will perform a grid search across multiple parameters to minimize the error on our dataset. We will also show visualizations to show how our model metrics improve over different values of these parameters.
We sourced the data from the Data Science in R website, but the corpus is also available at the SpamAssassin website
3 Question
Question 19: Consider the other parameters that can be used to control the recursive partitioning process. Read the documentation for them in the rpart.control documentation. Also, carry out an Internet search for more information on how to tweak the rpart tuning parameters. Experiment with values for these parameters. Do the trees that result make sense with your understanding of how the parameters are used? Can you improve the prediction using them?
4 Methods
4.1 Dataset
For this problem, we will be using the Spam Assassin Corpus which provides emails that can be used to train spam filters. A spam filter tries to judge whether or not an email is spam based on the contents and characteristics of an email. This corpus contains over nine thousand emails, pre-identified as spam or not-spam. There are five directories of messages: 3 folders for not spam and 2 for spam. Each folder contains full email messages including their header, body, and attachments. The header contains information about the recipients as well as other meta-data about the email. Content-Type, for instance, will help us identify if the message has an attachment. This information is typically in a “key: value” format which will be useful for parsing the information. The body contains the actual message as well as any attachments. We will run code to separate these elements and then create features that might be useful for identification of spam vs. not spam.
4.2 Dataset Preparation
Our main goals in preparing each message is to separate the header from the body, while also dealing with any attachments present. We will primarily use certain identifiable aspects of the email in the classification model, but first we need to get the emails into a form where we can derive features from them. There are several functions that aid in this process:
splitMessage is a function to assist in splitting the header from the body. It looks for the first blank line, which is the boundary between the two main parts of the email. It uses this index to assign each part to its own object, then returns them as a list.
getBoundary reads the header to determine the boundary that separates any attachments from the body. It uses regular expressions to search the header and determine the boundary.
processHeader is used to extract the “key: value” pairs from the header for later use.
processAttach separates out the attachment from the body with the help of the getBoundary function. It returns the body and some details about the attachment as a list.
readEmail this function reads the messages into R from their locations on the hard drive.
processAllEmail finally this function uses the functions above to separate each email message into its major components so we can derive features for classification from them.
4.3 Feature Creation
createDerivedDF uses a list of functions to derive features for classification. This holder function allows us to easily add and remove functions for creating features without having to readjust our code each time.
Below is a description of each derived feature. This information is also available in Table 3.1 in Nolan and Lang.
isSpam is a True / False flag if a message is spam. This will be the dependent variable for modeling.
isRe is a True / False flag if a message is a reply to an earlier message.
numLines is an integer representing the number of lines in the message.
bodyCharCt is an integer count of the number of characters in the message.
underscore is a True / False flag that indicates if there is an underscore in the Reply-To field of the header.
subExcCt is a count of exclamation points in the subject line.
subQuesCt is a count of question marks in the subject line.
numAtt is a count of the attachments in a message.
priority is a True / False flag if a priority key in the header (e.g. X-Priority) is set to high.
numRec a count of the recipients of a message, including CCs.
perCaps the percentage of capital letters in the message body, not including attachments.
isInReplyTo is a True / False flag if there is a In-Reply-To key in the header.
sortedRec is a True / False flag if the recipient addresses are sorted.
subPunc is a True / False flag indicating if the subject has numbers or punctuation embedded in it.
hour hour of the day in the date field.
multipartText is a True / False flag if more than half of the message's attachments have a text MIME type (html, plain, or text).
hasImages is a True / False flag if the message contains images.
isPGPsigned True if the message contains a PGP signature.
perHTML percentage of characters in HTML tags vs. all characters in the message.
subSpamWords True if the subject contains words from the spam list.
subBlanks percentage of blanks in the subject.
noHost True if there is no hostname in the Message-Id key in the header.
numEnd True if the email senders name ends in a number (before the @).
isYelling True if the subject is in all capital letters.
forwards percentage of lines in the message body that begin with the forward/quote symbol ">".
isOrigMsg True if the message body contains the words “original message”.
isDear True if the message body contains the word “dear”.
isWrote True if the message body contains the word “wrote:”.
avgWordLen The average length of words in the message.
numDlr number of dollar signs in the message body.
5 Modeling
5.1 CART
For our models, we use the caret package, which contains many models available for training using its train function. The model we will use to predict whether an email is spam or not-spam is CART (Classification and Regression Trees). The CART algorithm can be visualized as a branching tree with leaves or “nodes” representing the algorithm’s predictions, while the branches are calculated splits based on features in the model’s feature space, which, in our problem, are the different features of emails previously described in the Methods section. The algorithm repeatedly splits each feature to find the split which results in the greatest information gain. We give a simple example of one possible CART below.
Figure 1 shows the resulting CART tree for the spam data, using maxdepth=2. In this model, the feature perCaps was determined to be the best initial split, or root node, of the data between non-spam (0) and spam (1). perCaps < 13 checks the email message for whether the percentage of capital letters in the email is less than 13%. If it is, then it goes to the left node. Else, it goes to the right node.
- Left node - perHTML < 3.9: In addition to having less than 13% capitalized letters, if the message characters consist of less than 3.9% HTML tags, then we classify the message as non-spam. Otherwise, we classify the message as spam.
- Right node - numLines < 9.5: In addition to having greater than or equal to 13% capitalized letters, if the message has less than 9.5 lines, then we classify the message as non-spam. Otherwise, we classify the message as spam.
# This is a basic example of a CART with maxdepth=2
set.seed(123)
model_rpart_example <- caret::train(as.factor(isSpam) ~ ., data = emailDFnum, trControl = train_control,
method = "rpart", control = rpart.control(maxdepth = 2), tuneGrid = cart_grid,
na.action = na.omit)
# Let us visualize the tree
fancyRpartPlot(model_rpart_example$finalModel, sub = "")
We use the rpart package and rpart method in caret to implement CART in this case study.
5.2 Parameter Tuning
We can potentially improve upon the model above by tuning the parameters of the model. We took three different parameters–cp, minsplit, and maxdepth–and generated different CART models for each combination of unique values those parameters could take (see Table 1).
Table 1. Parameters of rpart::rpart.control and caret::train and how they were tuned.
| Parameter | Description | Default Value | Tuning Method |
|---|---|---|---|
cp |
The complexity parameter is a threshold CART applies to each potential split of the tree. A split that does not improve the overall fit of the tree by at least a factor of cp is not attempted. Hence, low values of cp may lead to overfitting. |
caret::train will iterate over various values of cp to find the optimal model, while the default in rpart::rpart is 0.01 |
seq(from = 0, to=0.1, by=0.001) |
minsplit |
The minimum split parameter controls the minimum number of observations required to split a node. We started with a value of 5 rather than the absolute minimum, 1, because lower values of minsplit lead to overfitting. |
20 |
seq(5,25,1) |
maxdepth |
The maximum depth parameter controls the maximum depth of the tree. Depending on the number of observations in the dataset, deeper trees could lead to overfitting. | 30 |
seq(15,30,1) |
There are certain parameters we did not tune as explained in Table 2 below.
Table 2. Parameters of rpart::rpart.control that we did not explicitly tune.
| Parameter | Description | Default Value |
|---|---|---|
minbucket |
The minimum bucket value controls the minimum number of observations allowed in a terminal node. Given that the default value of this parameter is by default determined by the minsplit parameter, we only tuned minsplit. |
round(minsplit/3) |
usesurrogate, maxsurrogate |
The parameters that control how surrogates behave only apply when there are missing values in the data. Our data does not contain missing values, and thus those parameters are irrelevant. | - |
5.3 Evaluation Metric
Several evaluation metrics can be used to evaluate performance for a classification algorithm:
- Accuracy - ratio of correctly classified observations to the total observations
- Precision - ratio of correctly predicted spam emails to the total predicted spam emails
- Recall - ratio of correctly predicted spam emails to the total number of actual spam emails
- F1 - harmonic mean of precision and recall, which accounts for both false positives and false negatives
- Type I Errors - incorrect identification of an email as spam
- Type II Errors - incorrect identification of an email as not spam
There are benefits and downsides to the various evaluation metrics listed above. Accuracy is easy to understand, but it treats false positives (incorrectly identifying an email as spam) and false negatives (failing to identify an email as spam) equally. For spam, we assume that people would prefer to get the occasional spam email instead of potentially not seeing an important email because it is misclassified as spam. Therefore, our primary objective is to maximize the precision of spam.
5.4 Effects of Parameter Tuning
The three graphs below show the effects of our parameter tuning. We compare the effects of different parameter values using a 3D mesh. Higher precision is better, so peaks in the 3D surface represent the best values of precision for a given combination of input parameters. We could use a 2D line to represent the best value of a parameter in isolation, but the 3D surface allows us to see the interaction between parameters. Since we have 3 parameters we are tuning, there will be multiple observations for any given combination of the two factors in the 3D plot. We take the average precision of the points for plotting purposes. Please note that Figure 2, Figure 3, and Figure 4 are interactive and you can move the plots around.
In Figure 2, we compare complexity and minsplit. For minsplit values less than 25, there is a plateau in precision between cp=0.023 to cp=0.087. Precision rapidly increases from cp=0.09 to cp=0.08, and again from cp=0.02 to cp=0. For any given value of cp, changing minsplit doesn’t seem to have much effect.
# average every metric over the remaining tuning parameter so that there is
# one precision value per (cp, minsplit) pair
plot_data <- final_result %>% group_by(cp, minsplit) %>% summarise_all(mean) %>%
select(prec, cp, minsplit)
# 3D mesh of mean precision over cp (x axis) and minsplit (y axis)
fig <- plot_ly(z = plot_data$prec, x = plot_data$cp, y = plot_data$minsplit, type = "mesh3d",
opacity = 0.7)
fig <- fig %>% layout(scene = list(xaxis = list(title = "Complexity"), yaxis = list(title = "Minsplit"),
zaxis = list(title = "Precision")))
fig
In Figure 3, we compare complexity and maxdepth. For maxdepth values less than 30, there is a plateau in precision from cp=0.08 to cp=0.023. Precision rapidly increases from cp=0.1 to cp=0.08, and again from cp=0.023 to cp=0. For any given value of cp, changing maxdepth doesn’t seem to have much effect.
plot_data <- final_result %>% group_by(cp, maxdepth) %>% summarise_all(mean) %>%
select(prec, cp, maxdepth)
fig2 <- plot_ly(z = plot_data$prec, x = plot_data$cp, y = plot_data$maxdepth, type = "mesh3d",
opacity = 0.7)
fig2 <- fig2 %>% layout(scene = list(xaxis = list(title = "Complexity"), yaxis = list(title = "Depth"),
zaxis = list(title = "Precision")))
fig2
In Figure 4 we compare different levels of maxdepth and minsplit. Here we can see the highest precision values when minsplit is low and maxdepth is high. When minsplit is in the 15-25 range, we see larger increases in precision when maxdepth increases from 15 to 20. Past maxdepth = 20, the increases to precision are small for minsplit 15-25. For minsplit 5-15, precision increases across the whole range of depths we tested. Decreases in split generally result in higher precision across different depths, although we see 3 peaks/troughs leading up to the maximum precision.
plot_data <- final_result %>% group_by(minsplit, maxdepth) %>% summarise_all(mean) %>%
select(prec, minsplit, maxdepth)
fig3 <- plot_ly(z = plot_data$prec, x = plot_data$minsplit, y = plot_data$maxdepth,
type = "mesh3d", opacity = 0.7)
fig3 <- fig3 %>% layout(scene = list(xaxis = list(title = "Split"), yaxis = list(title = "Depth"),
zaxis = list(title = "Precision")))
fig3
6 Results
336 different combinations of minsplit and maxdepth were passed to the caret::train function. The rpart method was able to tune cp independently. 101 different values of cp were evaluated for each combination of minsplit and maxdepth. This totals 33,936 combinations of minsplit, maxdepth and cp that were evaluated. CP appears to be the most important parameter to tune, leading to the largest gains in precision. Tuning max depth and minsplit will also help, but the impact is far less.
The winning model with the highest precision used the following parameter values: cp = 0, minsplit = 6, and maxdepth = 29. The resulting precision is 96.75%. We compared this to the model using the default values for parameters: cp = 0.07 (cp was tuned by default as part of rpart in caret::train), minsplit = 20, and maxdepth = 30. The resulting precision for the default model was 86.84%. Tuning parameters increased precision from 86.84% to 96.75%.
Figure 5 below shows the CART tree for the model using default values for the control parameters. The first split on the tree is perCaps < 13. The second split on the tree is perHTML < 3.9.
# tree with defaults
par(mar = c(5, 4, 10, 2))
fancyRpartPlot(model_rpart_defaults$finalModel, sub = "")
The CART tree for the winning model is shown below in Figure 6. We display this plot only to show the difference in complexity and the increase in maxdepth. The maxdepth makes this tree not easy to visualize.
Printing the output from Figure 6 is also quite verbose and we have chosen not to print it here. However, the first split on perCaps and the subsequent split for each branch are detailed below:
- Root Node
  - perCaps < 12.86: There are 8038 emails in this node. 1391 emails (17%) are incorrectly classified as spam.
    - perHTML < 3.93: There are 7401 emails in this node. 970 emails (13%) are incorrectly classified as spam.
    - perHTML >= 3.93: There are 637 emails in this node. 216 (34%) emails are incorrectly classified as not spam.
  - perCaps >= 12.86: There are 1310 emails in this node. 304 emails (23%) are incorrectly classified as not spam.
    - numLines < 9.5: There are 138 emails in this node. 17 emails (12%) are incorrectly classified as spam.
    - numLines >= 9.5: There are 1172 emails in this node. 183 emails (16%) are incorrectly classified as not spam.
In conclusion, our research of the different parameters available for tuning the CART algorithm–specifically, minsplit,maxdepth, and cp–led us to create a better model for classifying email messages as spam or non-spam. Tuning these model parameters significantly improved precision, from 86.84% to 96.75%. Additionally, our winning tree makes sense in terms of the resulting winning parameters. Our winning tree is relatively deep, with a low minimum split value and a low CP threshold. Those qualities can lead to complex trees that perform well on training data, as our winning model did. The values for the tuned parameters could be cause for concern in regards to overfitting. However, 5-fold cross validation was used in the model, which helps prevent against overfitting. The resulting tree is very deep and complex. The vast amount of data used to train the model makes this complexity possible to improve the model’s precision.
7 References
(http://rdatasciencecases.org/) Case Studies in Data Science with R by Nolan and Lang. (https://spamassassin.apache.org/downloads.cgi?update=20200129223) SpamAssassin Race website.